{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71a329f0",
   "metadata": {},
   "source": [
    "# BYOC instructions for using an LMI container on SageMaker\n",
    "In this tutorial, you will bring your own container (BYOC) from Docker Hub to SageMaker and run inference with it.\n",
    "Make sure the following permissions are granted before running the notebook:\n",
    "\n",
    "- ECR push/pull access\n",
    "- S3 bucket upload access\n",
    "- SageMaker access\n",
    "\n",
    "## Step 1: Upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fa3208",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker boto3 awscli --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9ac353",
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import Model, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "region = sess.boto_region_name  # region name of the current SageMaker session\n",
    "account_id = sess.account_id()  # account_id of the current SageMaker Studio environment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71542f98",
   "metadata": {},
   "source": [
    "## Step 2: Pull the Docker image from Docker Hub and push it to an ECR repository\n",
    "\n",
    "*Note: Make sure your AWS credentials have permission to push to the ECR repository.*\n",
    "\n",
    "This process may take a while, depending on the container size and your network bandwidth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1efb852",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "\n",
    "# Name of the ECR repository that will hold the image\n",
    "repo_name=djlserving-byoc\n",
    "# Source image on Docker Hub\n",
    "target_container=\"deepjavalibrary/djl-serving:0.27.0-deepspeed\"\n",
    "\n",
    "account=$(aws sts get-caller-identity --query Account --output text)\n",
    "\n",
    "# Get the region defined in the current configuration (default to us-west-2 if none defined)\n",
    "region=$(aws configure get region)\n",
    "region=${region:-us-west-2}\n",
    "\n",
    "fullname=\"${account}.dkr.ecr.${region}.amazonaws.com/${repo_name}:latest\"\n",
    "echo \"Target ECR image: ${fullname}\"\n",
    "\n",
    "# If the repository doesn't exist in ECR, create it.\n",
    "if ! aws ecr describe-repositories --region \"${region}\" --repository-names \"${repo_name}\" > /dev/null 2>&1\n",
    "then\n",
    "    aws ecr create-repository --region \"${region}\" --repository-name \"${repo_name}\" > /dev/null\n",
    "fi\n",
    "\n",
    "# Log Docker in to ECR with a temporary password\n",
    "aws ecr get-login-password --region \"${region}\" | docker login --username AWS --password-stdin \"${account}.dkr.ecr.${region}.amazonaws.com\"\n",
    "\n",
    "# Pull the image from Docker Hub, tag it with the full ECR name,\n",
    "# and push it to ECR.\n",
    "echo \"Start pulling container: ${target_container}\"\n",
    "\n",
    "docker pull \"${target_container}\"\n",
    "docker tag \"${target_container}\" \"${fullname}\"\n",
    "docker push \"${fullname}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81deac79",
   "metadata": {},
   "source": [
    "## Step 3: Prepare the model artifacts\n",
    "The LMI container expects the following artifacts to set up the model:\n",
    "- serving.properties (required): Defines the model server settings\n",
    "- model.py (optional): A Python file that defines the core inference logic\n",
    "- requirements.txt (optional): Any additional pip packages to install"
   ]
  },
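  {
   "cell_type": "markdown",
   "id": "5c2d7a10",
   "metadata": {},
   "source": [
    "As a rough illustration of the `serving.properties` file listed above, a minimal configuration might look like the sketch below. The model ID `gpt2` and every value shown are assumptions chosen for illustration, not settings taken from this tutorial. Check the LMI documentation for the options your container version supports.\n",
    "\n",
    "```\n",
    "# Run the model through the Python engine (assumed choice)\n",
    "engine=Python\n",
    "# Hugging Face model ID or S3 URI of the model (hypothetical example)\n",
    "option.model_id=gpt2\n",
    "# Number of GPUs to shard the model across\n",
    "option.tensor_parallel_degree=1\n",
    "# Load the weights in half precision\n",
    "option.dtype=fp16\n",
    "```\n",
    "\n",
    "To try it, paste these lines into the `%%writefile serving.properties` cell and adjust them for your model."
   ]
  },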
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b011bf5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile serving.properties\n",
    "# Start writing content here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d6798b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "# Start writing content here (remove this cell if not used)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b50a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile requirements.txt\n",
    "# Start writing content here (remove this file if not needed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0142973",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir -p mymodel\n",
    "mv serving.properties mymodel/\n",
    "# remove the following lines if not needed\n",
    "mv model.py mymodel/\n",
    "mv requirements.txt mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e58cf33",
   "metadata": {},
   "source": [
    "## Step 4: Build the SageMaker endpoint\n",
    "In this step, we build a SageMaker endpoint from scratch.\n",
    "\n",
    "### 4.1 Upload the artifact to S3 and create the SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b1e5ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_code_prefix = \"large-model-lmi/code\"\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel.tar.gz\", bucket, s3_code_prefix)\n",
    "print(f\"Model tarball uploaded to: {code_artifact}\")\n",
    "\n",
    "repo_name = \"djlserving-byoc\"  # must match repo_name used in Step 2\n",
    "image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{repo_name}:latest\"\n",
    "env = {\"HUGGINGFACE_HUB_CACHE\": \"/tmp\", \"TRANSFORMERS_CACHE\": \"/tmp\"}\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, env=env, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004f39f6",
   "metadata": {},
   "source": [
    "### 4.2 Create SageMaker endpoint\n",
    "\n",
    "Specify the instance type to use and the endpoint name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e0e61cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_type = \"ml.g4dn.4xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "\n",
    "model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             # container_startup_health_check_timeout=3600,  # uncomment to allow more time for large models to load\n",
    "            )\n",
    "\n",
    "# our requests and responses will be in json format so we specify the serializer and the deserializer\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    "    deserializer=deserializers.JSONDeserializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb63ee65",
   "metadata": {},
   "source": [
    "## Step 5: Test and benchmark the inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bcef095",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit -n3 -r1\n",
    "predictor.predict(\n",
    "    {\"inputs\": \"Large model inference is\", \"parameters\": {}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1cd9042",
   "metadata": {},
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d674b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.delete_endpoint(endpoint_name)\n",
    "sess.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
